-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix for gh-1468 in arithmetic reduction when type promotion is needed #1470
Conversation
Removed mention of dtype kwarg in usage line
Function _reduce_over_axis promotes input array to requested result data type and carries out reduction computation in that data type. This is done in dtype if implementation supports it. If implementation does not support the requested dtype, we reduce in the default_dtype, and cast to the request dtype afterwards.
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1470/index.html |
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_4 ran successfully. |
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_3 ran successfully. |
d5438d4
to
ff9b5eb
Compare
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_4 ran successfully. |
dpctl/tensor/_reduction.py
Outdated
@@ -118,7 +118,7 @@ def _reduction_over_axis( | |||
dpt.full( | |||
res_shape, | |||
_identity, | |||
dtype=_default_reduction_type_fn(inp_dt, q), | |||
dtype=dtype, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this change, the call to astype
is no longer necessary.
It also means that when logsumexp
or reduce_hypot
reduce over an empty axis or array (e.g., dpt.logsumexp(dpt.ones((1, 0, 1), dtype="i4"), axis=1, dtype="i4")
) you get OverflowError
instead of silently casting the identity to the output type.
For now, the astype
can be removed. I've experimented with removing this branch in #1465 too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and it should not be dtype
, it should be res_dt
.
Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_5 ran successfully. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @oleksandr-pavlyk !
Closes gh-1468
This PR changes
_reduce_over_axis
when an implementation to reduce using requesteddtype
for given input type does not exist. The change casts input array into a temporary array of requesteddtype
and calls reduction on the temporary (silent assumption that implementation to reduce input array using the same input dtype is available).